from torch.utils.data.dataset import T_co

from data.hintbase_dataset import HintDataset


class AnimehintDataset(HintDataset):
    """Hint dataset whose samples carry an image, a sketch and hints.

    Construction is delegated entirely to ``HintDataset``; the only logic
    defined here is ``postprocess``.
    """

    def __init__(self, opt):
        """Forward ``opt`` unchanged to the ``HintDataset`` constructor."""
        super().__init__(opt)

    def postprocess(self, input_dict):
        """Transform the ``'image'``, ``'sketch'`` and ``'hints'`` entries.

        A single parameter set is obtained from ``self.get_params`` using
        the ``size`` attribute of the image, and that same set is used to
        build both the general transform and the sketch transform. The
        sketch goes through the sketch transform; the image and the hints
        both go through the general transform.

        Args:
            input_dict: Mapping that contains at least the keys ``'image'``,
                ``'sketch'`` and ``'hints'``. It is modified in place.
                NOTE(review): the values are presumably PIL images (only
                ``image.size`` is read here) -- confirm against the caller.

        Returns:
            The same ``input_dict`` object, with the three entries replaced
            by their transformed versions.
        """
        raw_image = input_dict['image']
        raw_sketch = input_dict['sketch']
        raw_hints = input_dict['hints']

        # One parameter set feeds both transforms.
        shared_params = self.get_params(raw_image.size)
        general_tf = self.get_transform(shared_params)
        sketch_tf = self.get_sketch_transform(shared_params)

        # Application order is kept as sketch -> image -> hints.
        new_sketch = sketch_tf(raw_sketch)
        new_image = general_tf(raw_image)
        new_hints = general_tf(raw_hints)

        input_dict.update(sketch=new_sketch, image=new_image, hints=new_hints)
        return input_dict

